#include <limits.h>
#include <stdlib.h>
#include <stdio.h>
#include <string.h>

#include <sys/types.h>
#include <sys/stat.h>
#include <sys/time.h>
#include <fcntl.h>
#include <unistd.h>
#include <errno.h>

#include "jszt_cnn.h"
/*
 * Run-mode switch. main() sets it to 1 by default and to 0 when it is given
 * "<cpu_nums> <run_counts>". When non-zero, write_to_file() dumps buffers to
 * ./t_*.bin; when zero, write_to_file() is a no-op and test_cnn() takes the
 * timed benchmark path. Static storage, so it starts out as 0.
 */
static int golab_flag;
/* Worker count passed to jszt_init(); overridable from the command line. */
static int CPU_NUMS = 2;
/* Number of jszt_CNN() iterations in the timed benchmark path. */
static int run_counts = 1;
/*
 * Dump a byte buffer to one of the debug output files in the current
 * working directory. Does nothing unless golab_flag is non-zero.
 *
 * type: 0 -> ./t_stream.bin, 1 -> ./t_kerln.bin, 2 -> ./t_dstream.bin;
 *       any other value is reported as a failure and nothing is written.
 * buff: bytes to write (may be NULL only when len == 0).
 * len:  number of bytes to write.
 *
 * The file is created with mode 0644 (or truncated if it exists), written
 * completely, and closed. Success or failure is reported on stdout.
 */
void write_to_file(int type, unsigned char * buff, int len) {
    if (golab_flag == 0) {
        return;
    }

    const char *path = NULL;
    switch (type) {
    case 0:
        path = "./t_stream.bin";
        break;
    case 1:
        path = "./t_kerln.bin";
        break;
    case 2:
        path = "./t_dstream.bin";
        break;
    default:
        /* Previously fd stayed 0 here and the data went to stdin's fd. */
        printf("write to file %d fail, unknown type \n", type);
        return;
    }

    if (len < 0 || (buff == NULL && len > 0)) {
        printf("write to file %d fail, rlen: %d, len: %d \n", type, -1, len);
        return;
    }

    /* O_CREAT requires an explicit mode argument; O_TRUNC discards stale
     * bytes left over from an earlier, longer dump. */
    int fd = open(path, O_WRONLY | O_CREAT | O_TRUNC, 0644);
    if (fd < 0) {
        printf("open file is error %s \n", strerror(errno));
        return;
    }

    /* write() may transfer fewer bytes than requested or be interrupted;
     * keep going until everything is written or a real error occurs. */
    int rlen = 0;
    while (rlen < len) {
        ssize_t n = write(fd, buff + rlen, (size_t)(len - rlen));
        if (n < 0 && errno == EINTR) {
            continue;
        }
        if (n <= 0) {
            break;
        }
        rlen += (int)n;
    }

    if (close(fd) != 0) {
        printf("close file is error %s \n", strerror(errno));
    }

    if (rlen == len) {
        printf("write to file %d successful \n", type);
    } else {
        printf("write to file %d fail, rlen: %d, len: %d \n", type, rlen, len);
    }
}

/* Fill buf with len pseudo-random values in the range [0, 5]. */
static void fill_random_stream(uint8 *buf, int len) {
    for (int k = 0; k < len; k++) {
        buf[k] = (uint8)(rand() % 6);
    }
}

/*
 * Set the diagonal of an n x n row-major matrix stored in ker to 1.
 * ker must already be zero-filled (e.g. from calloc) and hold n*n bytes.
 */
static void fill_identity_kernel(uint8 *ker, int n) {
    for (int k = 0; k < n; k++) {
        ker[k * n + k] = 0x1;
    }
}

/*
 * Call jszt_CNN() `count` times (no calls if count <= 0), reporting each
 * failure, then print the elapsed wall-clock time with microsecond
 * resolution (the old whole-second figure read 0s for short runs).
 */
static void run_benchmark(CNN_PARA *para, int count) {
    struct timeval start, end;

    gettimeofday(&start, NULL);
    for (int remaining = count; remaining > 0; remaining--) {
        if (JS_OK != jszt_CNN(para)) {
            printf("err jszt_CNN \n");
        }
    }
    gettimeofday(&end, NULL);

    double diff = (double)(end.tv_sec - start.tv_sec) +
                  (double)(end.tv_usec - start.tv_usec) / 1e6;
    printf("cnt -> %d, diff: %.3fs \n", count, diff);
}

/*
 * End-to-end exercise of the jszt CNN library.
 *
 * 1. Initialise the library with CPU_NUMS workers.
 * 2. Build a 1024x1024 random input and a 9x9 identity kernel.
 * 3. Run jszt_CNN() once (golab_flag != 0, with file dumps) or run_counts
 *    times with timing (golab_flag == 0).
 * 4. Tear the library down and release every buffer allocated here.
 *    cPara.dstream is not freed; its ownership is not visible from this
 *    file (presumably the library's — TODO confirm).
 */
void test_cnn(void) {
    /* Declare everything before the first goto so no initialiser is skipped. */
    CNN_PARA cPara = { 0 };
    uint8 *stream = NULL;
    uint8 *cnnker = NULL;
    int total_len;
    int cnnker_len;

    // Initialise the library.
    if (JS_OK != jszt_init(CALCULAT_CPU, CPU_NUMS)) {
        printf("err jszt_init \n");
        return;
    }

    cPara.width = 1024;
    cPara.height = 1024;
    total_len = cPara.width * cPara.height;

    /* Check the allocation BEFORE touching it (the old check came after). */
    stream = (uint8 *)malloc((size_t)total_len);
    if (NULL == stream) {
        printf("err stream NULL \n");
        goto cleanup;
    }
    fill_random_stream(stream, total_len);

    // Dump the input array.
    write_to_file(0, stream, total_len);

    cPara.ckrlen[0] = 9;
    cPara.ckrlen[1] = 0;
    cPara.ckrlen[2] = 0;
    cPara.ckrlen[3] = 0;

    /* Index 3 is the last kernel length; the old code read ckrlen[4]. */
    cnnker_len = cPara.ckrlen[0] + cPara.ckrlen[1] + cPara.ckrlen[2] + cPara.ckrlen[3];
    cnnker_len *= cnnker_len;

    /* calloc: off-diagonal kernel entries must be 0, not uninitialised. */
    cnnker = (uint8 *)calloc((size_t)cnnker_len, 1);
    if (NULL == cnnker) {
        printf("err cnnker NULL \n");
        goto cleanup;
    }
    fill_identity_kernel(cnnker, cPara.ckrlen[0]);

    write_to_file(1, cnnker, cPara.ckrlen[0] * cPara.ckrlen[0]);

    cPara.stream = stream;
    cPara.cnnker = cnnker;

    printf("width: %d, height: %d, total_len %d \n", cPara.width, cPara.height, total_len);
    printf("cnnker_len: { %d %d %d %d } =>: %d \n", cPara.ckrlen[0], cPara.ckrlen[1], cPara.ckrlen[2], cPara.ckrlen[3], cnnker_len);
    printf("stream %p, dstream: %p, cnnker: %p, cpara: %p \n",
           (void *)cPara.stream, (void *)cPara.dstream, (void *)cPara.cnnker, (void *)&cPara);

    printf("start run \n");

    // Setup is done; tell each core's worker thread to get ready to compute.
    if (JS_OK != jszt_setReady()) {
        printf("err jszt_setReady \n");
    }

    /* Exactly one compute path runs; the old trailing extra jszt_CNN()
     * call (a leftover loop scaffold) is gone. */
    if (golab_flag == 0) {
        run_benchmark(&cPara, run_counts);
    } else if (JS_OK != jszt_CNN(&cPara)) {
        printf("err jszt_CNN \n");
    }

    // Computation finished; tell each core's worker thread to go idle.
    if (JS_OK != jszt_setEnd()) {
        printf("err jszt_setEnd \n");
    }

    write_to_file(2, cPara.dstream, cPara.dwidth * cPara.dheight);

cleanup:
    /* Destroy the library first in case it still references the buffers. */
    if (JS_OK != jszt_destroy()) {
        printf("err jszt_destroy \n");
    }
    free(cnnker);
    free(stream);
}



#include <errno.h>
#include <sys/types.h>
#include <fcntl.h>
/*
 * Parse s as a base-10 integer no smaller than min_value that fits in int.
 * On success, store it in *out and return 0. On malformed, trailing-garbage,
 * or out-of-range input, leave *out unchanged and return -1.
 */
static int parse_int_arg(const char *s, int min_value, int *out) {
    char *end = NULL;

    errno = 0;
    long v = strtol(s, &end, 10);
    if (end == s || *end != '\0' || errno == ERANGE ||
        v < min_value || v > INT_MAX) {
        return -1;
    }
    *out = (int)v;
    return 0;
}

/*
 * Entry point.
 *   prog                        -> golab_flag = 1 (buffers dumped to ./t_*.bin)
 *   prog <cpu_nums> <run_counts> -> golab_flag = 0 (timed benchmark), using
 *                                   cpu_nums >= 1 workers and run_counts >= 0
 *                                   iterations.
 * Returns 1 with a usage message when the numeric arguments are invalid
 * (atoi used to turn garbage into 0 silently).
 */
int main(int argc, char const *argv[]) {
    golab_flag = 1;

    if (argc == 3) {
        golab_flag = 0;
        if (parse_int_arg(argv[1], 1, &CPU_NUMS) != 0 ||
            parse_int_arg(argv[2], 0, &run_counts) != 0) {
            fprintf(stderr, "usage: %s [cpu_nums(>=1) run_counts(>=0)]\n", argv[0]);
            return 1;
        }
        printf("CPU NUMS: %d, calculate counts: %d \n", CPU_NUMS, run_counts);
    }

    test_cnn();
    return 0;
}
